{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mHDxn9VHjxKn"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "3x19oys5j89H"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hFDUpbtvv_3u"
      },
      "source": [
        "# Save and serialize models with Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "V94_3U2k9rWV"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/save_and_serialize\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/keras/save_and_serialize.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/keras/save_and_serialize.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/keras/save_and_serialize.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZwiVWAQc9tk7"
      },
      "source": [
        "The first part of this guide covers saving and serialization for Sequential models and models built using the Functional API. The saving and serialization APIs are exactly the same for both of these types of models.\n",
        "\n",
        "Saving for custom subclasses of `Model` is covered in the section \"Saving Subclassed Models\". The APIs in this case are slightly different from those for Sequential or Functional models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "uqSgPMHguAAs"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bx5w4U5muDAo"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "import tensorflow as tf\n",
        "\n",
        "tf.keras.backend.clear_session()  # For easy reset of notebook state."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wwCxkE6RyyPy"
      },
      "source": [
        "## Part I: Saving Sequential models or Functional models\n",
        "\n",
        "Let's consider the following model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ILmySACTvSA9"
      },
      "outputs": [],
      "source": [
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "inputs = keras.Input(shape=(784,), name='digits')\n",
        "x = layers.Dense(64, activation='relu', name='dense_1')(inputs)\n",
        "x = layers.Dense(64, activation='relu', name='dense_2')(x)\n",
        "outputs = layers.Dense(10, activation='softmax', name='predictions')(x)\n",
        "\n",
        "model = keras.Model(inputs=inputs, outputs=outputs, name='3_layer_mlp')\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xPRqbd0yw8hY"
      },
      "source": [
        "Optionally, let's train this model so that it has weight values and an optimizer state to save.\n",
        "You can also save a model you've never trained, but that's less interesting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gCygTeGQw74g"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
        "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n",
        "\n",
        "model.compile(loss='sparse_categorical_crossentropy',\n",
        "              optimizer=keras.optimizers.RMSprop())\n",
        "history = model.fit(x_train, y_train,\n",
        "                    batch_size=64,\n",
        "                    epochs=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "htnmbhz-iOwh"
      },
      "outputs": [],
      "source": [
        "# Save predictions for future checks\n",
        "predictions = model.predict(x_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "opP1KROHwWwd"
      },
      "source": [
        "\n",
        "### Whole-model saving\n",
        "\n",
        "You can save a model built with the Functional API into a single file. You can later recreate the same model from this file, even if you no longer have access to the code that created the model.\n",
        "\n",
        "This file includes:\n",
        "\n",
        "- The model's architecture\n",
        "- The model's weight values (which were learned during training)\n",
        "- The model's training config (what you passed to `compile`), if any\n",
        "- The optimizer and its state, if any (this enables you to restart training where you left off)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "HqHvq6Igw3wx"
      },
      "outputs": [],
      "source": [
        "# Save the model\n",
        "model.save('path_to_my_model.h5')\n",
        "\n",
        "# Recreate the exact same model purely from the file\n",
        "new_model = keras.models.load_model('path_to_my_model.h5')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "mmIcF6UOItJE"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Check that the state is preserved\n",
        "new_predictions = new_model.predict(x_test)\n",
        "np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)\n",
        "\n",
        "# Note that the optimizer state is preserved as well:\n",
        "# you can resume training where you left off."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "-WEPW3n8ICyz"
      },
      "source": [
        "### Export to SavedModel\n",
        "\n",
        "You can also export a whole model to the TensorFlow `SavedModel` format. `SavedModel` is a standalone serialization format for TensorFlow objects, supported by TensorFlow Serving as well as TensorFlow implementations other than Python."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "cKASRTKCU5nv"
      },
      "outputs": [],
      "source": [
        "# Export the model to a SavedModel\n",
        "model.save('path_to_saved_model', save_format='tf')\n",
        "\n",
        "# Recreate the exact same model\n",
        "new_model = keras.models.load_model('path_to_saved_model')\n",
        "\n",
        "# Check that the state is preserved\n",
        "new_predictions = new_model.predict(x_test)\n",
        "np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)\n",
        "\n",
        "# Note that the optimizer state is preserved as well:\n",
        "# you can resume training where you left off."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "4AWgwkKWIhfj"
      },
      "source": [
        "The `SavedModel` files that were created contain:\n",
        "\n",
        "- A TensorFlow checkpoint containing the model weights.\n",
        "- A `SavedModel` proto containing the underlying TensorFlow graph."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GkY8XP_XxgMI"
      },
      "source": [
        "### Architecture-only saving\n",
        "\n",
        "Sometimes, you are only interested in the architecture of the model, and you don't need to save the weight values or the optimizer. In this case, you can retrieve the \"config\" of the model via the `get_config()` method. The config is a Python dict that enables you to recreate the same model -- initialized from scratch, without any of the information learned previously during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yQGGvo2Fw4o-"
      },
      "outputs": [],
      "source": [
        "config = model.get_config()\n",
        "reinitialized_model = keras.Model.from_config(config)\n",
        "\n",
        "# Note that the model state is not preserved! We only saved the architecture.\n",
        "new_predictions = reinitialized_model.predict(x_test)\n",
        "# Compare element-wise: each softmax row sums to 1, so a plain sum of\n",
        "# the differences would cancel out to roughly zero.\n",
        "assert np.sum(np.abs(predictions - new_predictions)) \u003e 0."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "WsNBBvDgxsTS"
      },
      "source": [
        "You can alternatively use `to_json()` and `keras.models.model_from_json()`, which use a JSON string to store the config instead of a Python dict. This is useful for saving the config to disk."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "5a0z7_6XxqWV"
      },
      "outputs": [],
      "source": [
        "json_config = model.to_json()\n",
        "reinitialized_model = keras.models.model_from_json(json_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "SGC7R6IIxy0o"
      },
      "source": [
        "### Weights-only saving\n",
        "\n",
        "Sometimes, you are only interested in the state of the model -- its weight values -- and not in the architecture. In this case, you can retrieve the weight values as a list of NumPy arrays via `get_weights()`, and set the state of the model via `set_weights()`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "B8tHwEvkxw5E"
      },
      "outputs": [],
      "source": [
        "weights = model.get_weights()  # Retrieves the state of the model.\n",
        "model.set_weights(weights)  # Sets the state of the model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ydwtw-u2x7xC"
      },
      "source": [
        "You can combine `get_config()`/`from_config()` and `get_weights()`/`set_weights()` to recreate your model in the same state. However, unlike `model.save()`, this will not include the training config and the optimizer. You would have to call `compile()` again before using the model for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "LWVtuxtrx5lb"
      },
      "outputs": [],
      "source": [
        "config = model.get_config()\n",
        "weights = model.get_weights()\n",
        "\n",
        "new_model = keras.Model.from_config(config)\n",
        "new_model.set_weights(weights)\n",
        "\n",
        "# Check that the state is preserved\n",
        "new_predictions = new_model.predict(x_test)\n",
        "np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)\n",
        "\n",
        "# Note that the optimizer was not preserved,\n",
        "# so the model should be compiled anew before training\n",
        "# (and the optimizer will start from a blank state)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "prk0GzwCyIYy"
      },
      "source": [
        "The save-to-disk alternative to `get_weights()` and `set_weights(weights)`\n",
        "is `save_weights(fpath)` and `load_weights(fpath)`.\n",
        "\n",
        "Here's an example that saves to disk:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2irLnOUbyHlI"
      },
      "outputs": [],
      "source": [
        "# Save JSON config to disk\n",
        "json_config = model.to_json()\n",
        "with open('model_config.json', 'w') as json_file:\n",
        "    json_file.write(json_config)\n",
        "# Save weights to disk\n",
        "model.save_weights('path_to_my_weights.h5')\n",
        "\n",
        "# Reload the model from the 2 files we saved\n",
        "with open('model_config.json') as json_file:\n",
        "    json_config = json_file.read()\n",
        "new_model = keras.models.model_from_json(json_config)\n",
        "new_model.load_weights('path_to_my_weights.h5')\n",
        "\n",
        "# Check that the state is preserved\n",
        "new_predictions = new_model.predict(x_test)\n",
        "np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)\n",
        "\n",
        "# Note that the optimizer was not preserved."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "KBxcFAPHyYi5"
      },
      "source": [
        "But remember that the simplest, recommended way is just this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DE4b3ndNyQh3"
      },
      "outputs": [],
      "source": [
        "model.save('path_to_my_model.h5')\n",
        "del model\n",
        "model = keras.models.load_model('path_to_my_model.h5')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "yKikmbdC3O_i"
      },
      "source": [
        "### Weights-only saving using TensorFlow checkpoints\n",
        "\n",
        "Note that `save_weights` can create files either in the Keras HDF5 format,\n",
        "or in the [TensorFlow Checkpoint format](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint). The format is inferred from the file extension you provide: if it is \".h5\" or \".keras\", the framework uses the Keras HDF5 format. Anything else defaults to Checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "0pYKb6LV3h2l"
      },
      "outputs": [],
      "source": [
        "model.save_weights('path_to_my_tf_checkpoint')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZFwKv6JC3kyu"
      },
      "source": [
        "To be explicit, you can pass the format via the `save_format` argument, which can take the value \"tf\" or \"h5\":"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "oN9vOaWU34lA"
      },
      "outputs": [],
      "source": [
        "model.save_weights('path_to_my_tf_checkpoint', save_format='tf')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xXgtNRCSyuIW"
      },
      "source": [
        "## Part II: Saving Subclassed Models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "mJqOn0snzCRy"
      },
      "source": [
        "Sequential models and Functional models are data structures that represent a DAG of layers. As such,\n",
        "they can be safely serialized and deserialized.\n",
        "\n",
        "A subclassed model differs in that it's not a data structure but a piece of code. The architecture of the model\n",
        "is defined via the body of the `call` method. This means that the architecture of the model cannot be safely serialized. To load a model, you'll need to have access to the code that created it (the code of the model subclass). Alternatively, you could serialize this code as bytecode (e.g. via pickling), but that's unsafe and generally not portable.\n",
        "\n",
        "For more information about these differences, see the article [\"What are Symbolic and Imperative APIs in TensorFlow 2.0?\"](https://medium.com/tensorflow/what-are-symbolic-and-imperative-apis-in-tensorflow-2-0-dfccecb01021)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Pkwyu5dVz12P"
      },
      "source": [
        "Let's consider the following subclassed model, which follows the same structure as the model from the first section:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "4Onp-8rGyeQG"
      },
      "outputs": [],
      "source": [
        "class ThreeLayerMLP(keras.Model):\n",
        "\n",
        "  def __init__(self, name=None):\n",
        "    super(ThreeLayerMLP, self).__init__(name=name)\n",
        "    self.dense_1 = layers.Dense(64, activation='relu', name='dense_1')\n",
        "    self.dense_2 = layers.Dense(64, activation='relu', name='dense_2')\n",
        "    self.pred_layer = layers.Dense(10, activation='softmax', name='predictions')\n",
        "\n",
        "  def call(self, inputs):\n",
        "    x = self.dense_1(inputs)\n",
        "    x = self.dense_2(x)\n",
        "    return self.pred_layer(x)\n",
        "\n",
        "def get_model():\n",
        "  return ThreeLayerMLP(name='3_layer_mlp')\n",
        "\n",
        "model = get_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wwT_YoKA0yQW"
      },
      "source": [
        "First of all, *a subclassed model that has never been used cannot be saved*.\n",
        "\n",
        "That's because a subclassed model needs to be called on some data in order to create its weights.\n",
        "\n",
        "Until the model has been called, it does not know the shape and dtype of the input data it should be\n",
        "expecting, and thus cannot create its weight variables. You may remember that in the Functional model from the first section, the shape and dtype of the inputs were specified in advance (via `keras.Input(...)`) -- that's why Functional models have a state as soon as they're instantiated.\n",
        "\n",
        "Let's train the model, so as to give it a state:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "xqP4kIFN0fTZ"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n",
        "x_train = x_train.reshape(60000, 784).astype('float32') / 255\n",
        "x_test = x_test.reshape(10000, 784).astype('float32') / 255\n",
        "\n",
        "model.compile(loss='sparse_categorical_crossentropy',\n",
        "              optimizer=keras.optimizers.RMSprop())\n",
        "history = model.fit(x_train, y_train,\n",
        "                    batch_size=64,\n",
        "                    epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rvGCpyX72HOC"
      },
      "source": [
        "The recommended way to save a subclassed model is to use `save_weights` to create a TensorFlow checkpoint, which will contain the value of all variables associated with the model:\n",
        "\n",
        "- The layers' weights\n",
        "- The optimizer's state\n",
        "- Variables associated with stateful model metrics (if any)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "gMg87Tz01cxQ"
      },
      "outputs": [],
      "source": [
        "model.save_weights('path_to_my_weights', save_format='tf')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "KOKNBojtsl0F"
      },
      "outputs": [],
      "source": [
        "# Save predictions for future checks\n",
        "predictions = model.predict(x_test)\n",
        "# Also save the loss on the first batch\n",
        "# to later assert that the optimizer state was preserved\n",
        "first_batch_loss = model.train_on_batch(x_train[:64], y_train[:64])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "h2PM_PL1SzPo"
      },
      "source": [
        "To restore your model, you will need access to the code that created the model object.\n",
        "\n",
        "Note that in order to restore the optimizer state and the state of any stateful metric, you should\n",
        "compile the model (with the exact same arguments as before) and call it on some data before calling `load_weights`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "OOSGiSkHTERy"
      },
      "outputs": [],
      "source": [
        "# Recreate the model\n",
        "new_model = get_model()\n",
        "new_model.compile(loss='sparse_categorical_crossentropy',\n",
        "                  optimizer=keras.optimizers.RMSprop())\n",
        "\n",
        "# This initializes the variables used by the optimizer,\n",
        "# as well as any stateful metric variables\n",
        "new_model.train_on_batch(x_train[:1], y_train[:1])\n",
        "\n",
        "# Load the state of the old model\n",
        "new_model.load_weights('path_to_my_weights')\n",
        "\n",
        "# Check that the model state has been preserved\n",
        "new_predictions = new_model.predict(x_test)\n",
        "np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)\n",
        "\n",
        "# The optimizer state is preserved as well,\n",
        "# so you can resume training where you left off.\n",
        "# Compare with a tolerance rather than exact float equality.\n",
        "new_first_batch_loss = new_model.train_on_batch(x_train[:64], y_train[:64])\n",
        "np.testing.assert_allclose(first_batch_loss, new_first_batch_loss, rtol=1e-6, atol=1e-6)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2_XaE5Oqv7Rh"
      },
      "source": [
        "You've reached the end of this guide! It covers the main ways to save and serialize models with `tf.keras` in TensorFlow 2.0."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "save_and_serialize.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
